/*
 *  Copyright 2010 Brian S O'Neill
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package org.cojen.dirmi.io;

import java.io.IOException;

import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketTimeoutException;

import java.util.concurrent.TimeUnit;

import java.security.AccessControlContext;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;

import javax.net.ssl.SSLException;

import org.cojen.dirmi.ClosedException;
import org.cojen.dirmi.RejectedException;
import org.cojen.dirmi.RemoteTimeoutException;

import org.cojen.dirmi.util.Timer;

/**
 * Implements an acceptor using TCP/IP.
 *
 * <p>The acceptor binds a {@link ServerSocket} when constructed, turns each
 * accepted {@link Socket} into a {@link Channel} via
 * {@link #createChannel(SimpleSocket)}, and registers every channel with an
 * internal group which is closed together with the acceptor.
 *
 * @author Brian S O'Neill
 */
abstract class SocketChannelAcceptor implements ChannelAcceptor {
    /** Backlog requested when binding the server socket. */
    static final int LISTEN_BACKLOG = 1000;

    private final IOExecutor mExecutor;
    private final SocketAddress mLocalAddress;
    private final ServerSocket mServerSocket;
    // Captured at construction time; socket accepts are performed with it.
    private final AccessControlContext mContext;
    private final CloseableGroup<Channel> mAccepted;

    // Set once an asynchronous accept (see accept(Listener)) has succeeded.
    volatile boolean mAnyAccepted;

    /**
     * Creates an acceptor which binds a new, plain server socket.
     *
     * @param executor runs asynchronous accepts; must not be null
     * @param localAddress address to bind to
     * @throws IllegalArgumentException if executor is null
     * @throws IOException if the socket cannot be configured or bound
     */
    public SocketChannelAcceptor(IOExecutor executor, SocketAddress localAddress)
        throws IOException
    {
        this(executor, localAddress, new ServerSocket());
    }

    /**
     * Creates an acceptor which binds the given server socket. The acceptor
     * takes ownership of the socket: it is closed by {@link #close()}, and
     * also when configuring or binding it fails here, so that a failed
     * construction does not leave the socket open.
     *
     * @param executor runs asynchronous accepts; must not be null
     * @param localAddress address to bind to
     * @param serverSocket socket to bind; must not be null
     * @throws IllegalArgumentException if executor or serverSocket is null
     * @throws IOException if the socket cannot be configured or bound
     */
    public SocketChannelAcceptor(IOExecutor executor, SocketAddress localAddress,
                                 ServerSocket serverSocket)
        throws IOException
    {
        if (executor == null) {
            throw new IllegalArgumentException("Must provide an executor");
        }
        if (serverSocket == null) {
            throw new IllegalArgumentException("Must provide a server socket");
        }

        mExecutor = executor;
        mServerSocket = serverSocket;

        boolean bound = false;
        try {
            serverSocket.setReuseAddress(true);
            serverSocket.bind(localAddress, LISTEN_BACKLOG);
            bound = true;
        } finally {
            if (!bound) {
                // No acceptor instance will exist for anyone to close, so
                // release the socket before the failure propagates.
                try {
                    serverSocket.close();
                } catch (IOException ignored) {
                    // The original failure is the one worth reporting.
                }
            }
        }

        mLocalAddress = serverSocket.getLocalSocketAddress();
        mContext = AccessController.getContext();
        mAccepted = new CloseableGroup<Channel>();
    }

    /**
     * Accepts a channel, waiting without any timeout.
     */
    @Override
    public Channel accept() throws IOException {
        // A negative timeout selects the "no timeout" branch, in which the
        // unit is never dereferenced.
        return accept(-1, null);
    }

    /**
     * Accepts a channel, waiting at most the given amount of time.
     *
     * <p>A negative timeout waits indefinitely, as does a timeout which
     * exceeds {@code Integer.MAX_VALUE} milliseconds. A timeout which
     * converts to zero milliseconds fails immediately, without attempting
     * to accept.
     *
     * <p>This method is synchronized because it sets the timeout of the
     * shared server socket on every call.
     *
     * @param timeout maximum time to wait; negative for no limit
     * @param unit unit of timeout; only used when timeout is not negative
     * @return a new channel, registered with this acceptor's group
     * @throws RemoteTimeoutException if the timeout elapsed
     * @throws IOException if accepting or creating the channel failed
     */
    @Override
    public synchronized Channel accept(long timeout, TimeUnit unit) throws IOException {
        // NOTE(review): presumably throws if the acceptor is closed -- confirm
        // against CloseableGroup.
        mAccepted.checkClosed();

        if (timeout < 0) {
            // Zero means "no timeout" to ServerSocket.
            mServerSocket.setSoTimeout(0);
        } else {
            long millis = unit.toMillis(timeout);
            if (millis <= 0) {
                throw new RemoteTimeoutException(timeout, unit);
            } else if (millis > Integer.MAX_VALUE) {
                // Too large for the int parameter; wait indefinitely instead.
                mServerSocket.setSoTimeout(0);
            } else {
                mServerSocket.setSoTimeout((int) millis);
            }
        }

        Socket socket;
        try {
            socket = acceptSocket();
        } catch (SocketTimeoutException e) {
            throw new RemoteTimeoutException(timeout, unit);
        } catch (IOException e) {
            // Gives the closed check a chance to report first if the failure
            // came from closing the acceptor.
            mAccepted.checkClosed();
            throw e;
        }

        Channel channel = null;
        try {
            socket.setTcpNoDelay(true);
            channel = createChannel(SocketChannel.toSimpleSocket(socket));
        } finally {
            if (channel == null) {
                // No channel took over the socket, so nothing else can
                // release it. Closing an already closed socket is harmless.
                try {
                    socket.close();
                } catch (IOException ignored) {
                    // The original failure is the one worth reporting.
                }
            }
        }

        channel.register(mAccepted);
        return channel;
    }

    /**
     * Accepts a channel, waiting for whatever time remains on the timer.
     *
     * @param timer supplies the remaining time and its unit
     */
    @Override
    public Channel accept(Timer timer) throws IOException {
        mAccepted.checkClosed();
        // NOTE(review): checkRemaining presumably throws once the timer has
        // expired -- confirm against RemoteTimeoutException.
        return accept(RemoteTimeoutException.checkRemaining(timer), timer.unit());
    }

    /**
     * Accepts a channel asynchronously, using the executor. Exactly one of
     * the listener's methods is invoked: {@code accepted} on success,
     * {@code closed} if the acceptor is closed, {@code failed} for any other
     * I/O failure, or {@code rejected} if the executor refuses the task.
     *
     * @param listener receives the outcome of the accept
     */
    @Override
    public void accept(final Listener listener) {
        try {
            mExecutor.execute(new Runnable() {
                public void run() {
                    if (mAccepted.isClosed()) {
                        listener.closed(new ClosedException());
                        return;
                    }

                    Channel channel;
                    try {
                        try {
                            channel = accept();
                            mAnyAccepted = true;
                        } catch (SSLException e) {
                            if (!mAnyAccepted && e.getClass() == SSLException.class) {
                                // General SSL exception upon first accept
                                // indicates SSL subsystem is not configured
                                // correctly.
                                close();
                            }
                            throw e;
                        }
                    } catch (IOException e) {
                        if (mAccepted.isClosed()) {
                            listener.closed(e);
                        } else {
                            listener.failed(e);
                        }
                        return;
                    }

                    listener.accepted(channel);
                }
            });
        } catch (RejectedException e) {
            listener.rejected(e);
        }
    }

    /**
     * Closes the group of accepted channels and then the server socket. Any
     * failure to close the server socket is ignored.
     */
    @Override
    public void close() {
        mAccepted.close();
        try {
            mServerSocket.close();
        } catch (IOException e) {
            // Ignore.
        }
    }

    @Override
    public String toString() {
        return "ChannelAcceptor {localAddress=" + mLocalAddress + '}';
    }

    /**
     * Returns the address the server socket reported after it was bound.
     */
    @Override
    public final SocketAddress getLocalAddress() {
        return mLocalAddress;
    }

    /**
     * Returns the executor provided at construction.
     */
    protected IOExecutor executor() {
        return mExecutor;
    }

    /**
     * Wraps a newly accepted socket in a channel. Called by
     * {@link #accept(long, TimeUnit)} for every accepted connection; if this
     * method throws, the accepted socket is closed by the caller.
     *
     * @param socket the accepted connection
     * @return the channel to hand out; expected to be non-null
     */
    abstract Channel createChannel(SimpleSocket socket) throws IOException;

    /**
     * Performs the blocking accept as a privileged action, using the access
     * control context captured when this acceptor was constructed.
     */
    private Socket acceptSocket() throws IOException {
        try {
            return AccessController.doPrivileged(new PrivilegedExceptionAction<Socket>() {
                public Socket run() throws IOException {
                    return mServerSocket.accept();
                }
            }, mContext);
        } catch (PrivilegedActionException e) {
            // The action declares only IOException, so that is the only
            // checked exception which can have been wrapped.
            throw (IOException) e.getCause();
        }
    }
}